Re-implement FlashAttention with new Xe atoms #547
Conversation
I will break up this large commit into self-contained smaller commits after review is complete.
The following command encounters an accuracy issue (Disposition: Failed) with seq_len_kv=256. However, when seq_len_kv is changed to 512 or higher, the example passes successfully.
@ClarkChin08 I pushed a patch to fix issues like this earlier today. I double-checked your test case, and it's passing on my system; can you double-check with the latest commit?
Yes, passed now.
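For context on what an accuracy check like the one above compares against, here is a minimal sketch of a naive single-head attention reference (O = softmax(QK^T / sqrt(d)) V) that a flash-attention kernel's output could be verified against. This is an illustrative standalone function, not the repository's actual verification code; the function name and layout (row-major, one head) are assumptions.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical reference: naive single-head attention, row-major layouts.
// Q: [seq_q x d], K: [seq_kv x d], V: [seq_kv x d], returns O: [seq_q x d].
std::vector<float> attention_ref(const std::vector<float>& Q,
                                 const std::vector<float>& K,
                                 const std::vector<float>& V,
                                 int seq_q, int seq_kv, int d) {
  std::vector<float> O(static_cast<size_t>(seq_q) * d, 0.0f);
  const float scale = 1.0f / std::sqrt(static_cast<float>(d));
  for (int i = 0; i < seq_q; ++i) {
    // Scaled scores for query row i, tracking the row max for stability.
    std::vector<float> s(seq_kv);
    float row_max = -INFINITY;
    for (int j = 0; j < seq_kv; ++j) {
      float acc = 0.0f;
      for (int k = 0; k < d; ++k) acc += Q[i * d + k] * K[j * d + k];
      s[j] = acc * scale;
      row_max = std::max(row_max, s[j]);
    }
    // Numerically stable softmax over the key dimension.
    float sum = 0.0f;
    for (int j = 0; j < seq_kv; ++j) {
      s[j] = std::exp(s[j] - row_max);
      sum += s[j];
    }
    // Weighted sum of value rows.
    for (int j = 0; j < seq_kv; ++j) {
      const float p = s[j] / sum;
      for (int k = 0; k < d; ++k) O[i * d + k] += p * V[j * d + k];
    }
  }
  return O;
}
```

A test harness would compare this reference output element-wise (within a tolerance) against the kernel's output for the failing shape, e.g. seq_len_kv=256.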
applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp
Note: the CI is currently failing with compile-time divide-by-zero errors, but I can't reproduce the errors locally with any compiler/compile flags. If anyone can, let me know.
I didn't realize CI was merging branches into main prior to testing. Thanks to @rolandschulz for helping figure this out. The branch is now rebased and split into a logical set of patches.
This PR updates FlashAttention to the new copy/MMA atoms.
Current status: the prefill/decode examples are working, with performance similar to or better than the old examples.
Known issues: additional features (causal masking, variable sequence lengths, etc.) are to be added later.
Reminder: the new atoms require a very recent driver due to necessary IGC fixes/enhancements. Recommended version: ci-comp_igc-30613.